{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MNIST example with 2-conv. layer network\n",
    "\n",
    "This example demonstrates the usage of `FastaiLRFinder` with a 2-conv. layer network on the MNIST dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "is_executing": false
    }
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import copy\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision.datasets import MNIST"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist_pwd = \"data\"  # root directory the MNIST files are downloaded to / read from\n",
    "batch_size = 256  # training mini-batch size; the test loader uses twice this value\n",
    "\n",
    "# Seed PyTorch's random number generators (weight init, shuffling, dropout) so that\n",
    "# re-running the notebook is as repeatable as possible.\n",
    "seed = 42\n",
    "torch.manual_seed(seed);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert images to tensors, then normalise the single channel with mean 0.1307 and std 0.3081.\n",
    "transform = transforms.Compose(\n",
    "    [\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Training split: reshuffled every epoch and loaded by 4 worker processes.\n",
    "trainset = MNIST(root=mnist_pwd, train=True, transform=transform, download=True)\n",
    "trainloader = DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True, num_workers=4)\n",
    "\n",
    "# Test split: fixed order, loaded in the main process, with batches twice as large.\n",
    "testset = MNIST(root=mnist_pwd, train=False, transform=transform, download=True)\n",
    "testloader = DataLoader(dataset=testset, batch_size=2 * batch_size, shuffle=False, num_workers=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    \"\"\"Small CNN for 1x28x28 images: two conv/max-pool stages followed by two linear layers.\n",
    "\n",
    "    The forward pass returns a ``(batch, 10)`` tensor of log-probabilities\n",
    "    (``log_softmax`` over the class dimension).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n",
    "        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n",
    "        self.conv2_drop = nn.Dropout2d()\n",
    "        # 320 = 20 channels * 4 * 4, the feature-map size produced by a 28x28 input.\n",
    "        self.fc1 = nn.Linear(320, 50)\n",
    "        self.fc2 = nn.Linear(50, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return per-class log-probabilities for a batch ``x`` of shape ``(batch, 1, 28, 28)``.\"\"\"\n",
    "        x = F.relu(F.max_pool2d(self.conv1(x), 2))\n",
    "        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n",
    "        # Flatten per sample, keeping the batch dimension, so an input of the wrong spatial\n",
    "        # size fails in fc1 rather than being silently regrouped into a different number of rows.\n",
    "        x = x.view(x.size(0), -1)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        # Functional dropout must be told explicitly whether the module is in training mode.\n",
    "        x = F.dropout(x, training=self.training)\n",
    "        x = self.fc2(x)\n",
    "        return F.log_softmax(x, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ignite.engine import create_supervised_trainer, create_supervised_evaluator\n",
    "from ignite.metrics import Loss, Accuracy\n",
    "from ignite.contrib.handlers import FastaiLRFinder, ProgressBar"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training loss (fastai)\n",
    "\n",
    "This learning rate range test follows the same procedure used by fastai. The model is trained for `num_iter` iterations while the learning rate is increased from the initial value set in the optimizer to `end_lr`. The increase can be linear (`step_mode=\"linear\"`) or exponential (`step_mode=\"exp\"`); linear provides good results for small ranges while exponential is recommended for larger ranges."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when one is available, otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "\n",
    "# The model is moved to `device` first and only then are its parameters handed to the optimizer.\n",
    "model = Net()\n",
    "model.to(device)\n",
    "optimizer = optim.SGD(model.parameters(), lr=3e-4, momentum=0.9)\n",
    "\n",
    "# Negative log-likelihood loss, applied to the log-probabilities returned by the model.\n",
    "criterion = nn.NLLLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6e8da65177214eab9b765283023eb9de",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e2b83b9cd4ba4bf2b34b306e6aef3915",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "be3d08f1fe424d97b950542d9777d196",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "55e465e37a3d464081e62c9eafe0037c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f278da2d4a5f41709dff48b168c170a0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "eab71be3b1c74f2f851c0b14dbcd4971",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8aaed7e7cf654eb2bee927befa08b971",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3ed399388dd74ab29004ca1fb7d9dd44",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "83bf9d6b011f4ec99a5b0c5c1e2d6f13",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6f5e69329c9945fe8b0548f38fd4347e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "{'acc': 0.9021, 'loss': 0.3838102577209473}\n"
     ]
    }
   ],
   "source": [
    "trainer = create_supervised_trainer(model, optimizer, criterion, device=device)\n",
    "ProgressBar(persist=True).attach(trainer, output_transform=lambda x: {\"batch loss\": x})\n",
    "\n",
    "# Learning-rate range test. The objects in `to_save` are restored to their initial\n",
    "# states when the finder finishes.\n",
    "lr_finder = FastaiLRFinder()\n",
    "to_save = {\"model\": model, \"optimizer\": optimizer}\n",
    "with lr_finder.attach(trainer, to_save, diverge_th=1.5) as trainer_with_lr_finder:\n",
    "    trainer_with_lr_finder.run(trainloader)\n",
    "\n",
    "# Snapshot the initial weights and optimizer state so the baseline run can be undone.\n",
    "initial_state = {name: copy.deepcopy(obj.state_dict()) for name, obj in to_save.items()}\n",
    "\n",
    "# Baseline: 10 epochs with the learning rate the optimizer was created with.\n",
    "trainer.run(trainloader, max_epochs=10)\n",
    "\n",
    "evaluator = create_supervised_evaluator(model, metrics={\"acc\": Accuracy(), \"loss\": Loss(criterion)}, device=device)\n",
    "evaluator.run(testloader)\n",
    "\n",
    "print(evaluator.state.metrics)\n",
    "\n",
    "# Roll back the baseline training so that training with the suggested learning rate\n",
    "# starts from the same initial state instead of continuing from these weights.\n",
    "for name, obj in to_save.items():\n",
    "    obj.load_state_dict(initial_state[name])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3deXxU9b3/8ddnJhtJIBASEAgQFkEQWSQC7su1bm1Fr1hLFatVuVRt7W1/Xe69bW/vtba93WutCypaW9G2atVSrW0VS1UoBpAdlH2XsCeB7J/fHxkoYhICzMnJzLyfj8c8mDnnO+e88RjeOXPOnGPujoiIpK5I2AFERCRcKgIRkRSnIhARSXEqAhGRFKciEBFJcSoCEZEUlxZ2gGNVUFDgxcXFYccQEUko8+bN2+HuhU3NS7giKC4uprS0NOwYIiIJxczWNzdPHw2JiKQ4FYGISIpTEYiIpDgVgYhIilMRiIikOBWBiEiKS7jTR4/X3v21bNy9n2jESIsYkYgBsGd/DXsP1GI0TouaEYlAxIxoxA79eXD6wefWzHTsn+u02AszSI9EyEyPkJkWwcyaiigiEoqUKYI3Vu3gjunzw45BVnqErjmZZKZHGgvpYKEc8TwzLUJWerTxkRahQ0b00OsO6VE6pDdOy8lMo2NWOh2z0uiUlUbXnEw6Z6erbESk1VKmCEb37cJDk0bT0ODUu1Pf0HhDnrwO6eR1SAegwZ36hsY/Dx93cPrB5w1HTD9ymQCH3+7H3amtd6rr6tldWcOOihpq6hqoj73vyHXV1TsV1XWUlVdTXddAVW09B2rrOVBTT3Vdw1H/rhlpEXIz00iLGOnRCOlRIys9SqdYYeRmpTX+mdn4OicjSnZmGrmZaeRkppGb2VgwORn/nJaRpk8RRZJVyhTBSXlZnJR3UtgxTlhDg1NT38CBmnr219ZTWV1HeVUt+6rq2Heglp0VNbxfXsX+6npq6xuorXdq6xvYX1NPRXUt2/ZVUb69jorY+2rrW3eHuoxohJxYQeRmNhZJx6x0cjNjxRKbnpt1cH46XbLT6ZKTQZfsDDpnp5MeVZmItEcpUwTJIhIxsiKNHxF1OcFluTvVdY0lUVndWA7//POIaTV17K+upyI2raKqcY9l7Y5KyqvqqKiupaq25b2VjplpdM5Jp0t2Yzl0zc2gMDeTrrmNrw/upeRkRunUIZ387AzyOqQfOp4jIsFQEaQwMzt03CE/J+OEl1db3xDbQ6lj74Fa9uyvZff+Gvbsr2FX5T+f795fy67KGt57v7zxY7L65gskIy3CgMJc+hfkUNgx84OP3Ey6dcwkPyeDNO1tiBw3FYHETXo0QufsDDpnZ9C7le9xd8qr69hdWfOBPZG9BxrLYtu+Kt57v5xlW/ex491qyqvrPrQMMyjIzaR/QQ4Du+UyoDCXgd1y6VeQQ8/OHYhqj0KkRSoCCZWZ0SkrnU5Z6a0af6Cmnh0V1Wwvr6asvJqyisY/t+45wJodlcxYtJW9B2oPjc+IRujTNZvirjn0K8imuCCHfgU59C/IpXunTJ1dJYKKQBJMh4wovfOz6Z2f3eR8d2dnZQ2rt1ewbmcla3ZUsm5HJet27GfWe2XUHHbWVdecDIb1yuO0XnkM65XH4JM60ic/W3sQknJUBJJUzIyC3EwKcjMZ27/rB+Y1NDhb91WxtqySVdvLWbplH4s37+WNVTsOnfqblR5hRFFnxvbL54x++Zzepws5mfoxkeSm/8MlZUQiRq/OHejVuQPnnFxwaHpVbT0rt5Wz8v1ylm/dR+m63dw3cxUNr0FaxDhvUCHjR/bk4iHdVQqSlPR/taS8rPQoI3p3ZkTvzoemlVfVMn/DHt5ctYMZC7fw2ortdEiP8pGh3Zl0Zl9K+nbR8QVJGubeui8UtRclJSWuW1VKW2pocErX7+aFdzYfOhg9oiiPW8/tz+XDTtKpq5IQzGye
u5c0OU9FINJ6+2vqeHb+Zqa9sZa1Oyrp1bkDN59dzHVn9KZjK898EgmDikAkzhoanFdXbOfhv69h7tpddMxM47MXDuCWc/qRmRYNO57Ih7RUBIHt05pZbzObaWbLzGypmd3VxJjrzWyRmS02s7fMbERQeUTiKRIxPjK0O7/9tzN58c6zGTegK9//00ou/cks3l63K+x4IsckyA8364AvuftQYBxwh5kNPWLMWuB8dz8NuBuYGmAekUAML+rMwzeW8MRnxuDAJ6fOYeqs1STa3rakrsCKwN23uvv82PNyYDnQ64gxb7n77tjLOUBRUHlEgnbeoEL+8LlzuGRod77z0gomPTqXTbv3hx1L5Kja5HQHMysGRgH/aGHYLcDLzbx/spmVmllpWVlZ/AOKxEmnrHTuv/507rl6GAs27Oayn/6dh/62mqra+rCjiTQr8IPFZpYL/A24x92fa2bMhcD9wDnuvrOl5elgsSSKjbv2840XlvD6yjJ65mXx80+NYnTf/LBjSYoK5WBxbMXpwLPAky2UwHDgEWD80UpAJJH0zs/m8ZvHMP3WsWSkRbhp2tss2bw37FgiHxLkWUMGPAosd/cfNzOmD/AcMMnd3w0qi0iYzhpYwPTbxtGpQzo3TpvL6rKKsCOJfECQewRnA5OAi8zsndjjCjObYmZTYmO+CXQF7o/N12c+kpR6du7Ar28dS8Tg5sfeZmdFddiRRA7RF8pE2tCCDbv55NQ5nNqzE9NvG0dWur58Jm0jtGMEIvJBo/p04afXjWTBxj18etpc9u6vPfqbRAKmIhBpY5ef1qOxDDbs4V8feJPNew6EHUlSnIpAJATjR/biiVvGsL28mpumzWVflfYMJDwqApGQjOvflYduGM3aHZXc8eR8ausbjv4mkQCoCERCdNbAAu65ehh/f28H33lpedhxJEXpDmUiIbvujD4s31rOY2+uY3TfLnxseM+wI0mK0R6BSDvwn1cM4fQ+nfnqM4v0hTNpcyoCkXYgIy3CL64/nfS0CF97dhENDYn1/R5JbCoCkXaiR14H/vPyIby9bjfPzNsUdhxJISoCkXZkwugixhTn852Xl7OrsibsOJIiVAQi7UgkYnz76mFUVNXxrReXhh1HUoSKQKSdGdS9I5+76GReXLiFlxZvDTuOpAAVgUg7dPuFAxhelMfXn19CWbmuVCrBUhGItEPp0Qg/unYEFdX6iEiCpyIQaadO7t6Rz104kD8u3srMFdvDjiNJTEUg0o5NPr8/A7vl8o0XlnCgpj7sOJKkVAQi7VhmWpR7rhrGpt0H+MXMVWHHkSSlIhBp58b278qVI3ryyBtr2La3Kuw4koSCvHl9bzObaWbLzGypmd3VxBgzs3vNbJWZLTKz04PKI5LIvnzpYOobnJ/+9d2wo0gSCnKPoA74krsPBcYBd5jZ0CPGXA6cHHtMBh4IMI9Iwuqdn82kccX8tnQj771fHnYcSTKBFYG7b3X3+bHn5cByoNcRw8YDT3ijOUBnM+sRVCaRRHbnRQPJyUjjZ6++F3YUSTJtcozAzIqBUcA/jpjVC9h42OtNfLgsMLPJZlZqZqVlZWVBxRRp1/JzMpg4tg8vL9nGpt37w44jSSTwIjCzXOBZ4Avuvu94luHuU929xN1LCgsL4xtQJIHcdFYxBjz25rqwo0gSCbQIzCydxhJ40t2fa2LIZqD3Ya+LYtNEpAk9O3fgY8N78PTcDew9oBveS3wEedaQAY8Cy939x80MexG4MXb20Dhgr7vrKlsiLbj13P5U1tTz9NwNYUeRJBHkHsHZwCTgIjN7J/a4wsymmNmU2JiXgDXAKuBh4PYA84gkhWG98jizf1cee3MdNXUNYceRJBDYzevd/Q3AjjLGgTuCyiCSrCaf15+bH3+bPy7ewtWjisKOIwlO3ywWSUDnDypkYLdcHp61lsbfp0SOn4pAJAFFIsat5/Rj2dZ9zF69M+w4kuBUBCIJ6qpRvSjIzWDq39eEHUUSnIpA
JEFlpUe58cxiXl9ZpstOyAlREYgksBvG9SUrPcIjf18bdhRJYCoCkQSWn5PBhNFF/H7BZraX6xLVcnxUBCIJ7pZz+lPb0MCvZ68PO4okKBWBSILrV5DDeScX8rt5m6hv0KmkcuxUBCJJ4NqSIrbureKt1TvCjiIJSEUgkgQuHtKdTllp/K50U9hRJAGpCESSQFZ6lPEje/HK0m26KqkcMxWBSJKYMLqI6roGZizaEnYUSTAqApEkMbwoj1NO6siTczbo+kNyTFQEIknCzJh0Zl+Wbd1H6frdYceRBKIiEEkiV4/qRaesNH751rqwo0gCURGIJJHsjDQ+UdKbPy3Zxvv79E1jaR0VgUiSmXRmX+rdeXKOvmksraMiEEkyfbvmcOHgbkyfu4Hquvqw40gCCPLm9dPMbLuZLWlmfp6Z/cHMFprZUjO7OagsIqnm02cVs6OihpcXbws7iiSAIPcIHgcua2H+HcAydx8BXAD8yMwyAswjkjLOHVhA/4IcHtdBY2mFwIrA3WcBu1oaAnQ0MwNyY2PrgsojkkoiEePGM/vyzsY9LNy4J+w40s6FeYzgPmAIsAVYDNzl7g1NDTSzyWZWamalZWVlbZlRJGFdM7qInIwoT+jy1HIUYRbBpcA7QE9gJHCfmXVqaqC7T3X3EncvKSwsbMuMIgmrY1Y6V43qxYxFW9i7X9cfkuaFWQQ3A895o1XAWuCUEPOIJJ2JY/pQXdfA8+9sDjuKtGNhFsEG4F8AzKw7MBhYE2IekaQzrFcew4vymP4PXX9Imhfk6aNPAbOBwWa2ycxuMbMpZjYlNuRu4CwzWwy8CnzV3XVXDZE4mzimDyvfL2f+Bh00lqalBbVgd594lPlbgEuCWr+INPr4iJ58e8Yynpq7gdF9u4QdR9ohfbNYJMnlZqYx/uBBY920RpqgIhBJAZ8a04eq2gZe0EFjaYKKQCQFDOuVx2m9dNBYmqYiEEkRE8f0YcW2chbom8ZyBBWBSIq4cmRPsjOi/FrfNJYjqAhEUkRuZhqfPKMPLyzcwvqdlWHHkXZERSCSQqac35+0iHHfa6vCjiLtiIpAJIV065TF9WP78tyCzdorkENUBCIp5uBewQOvrw47irQTKgKRFNOtUxZXjezFiwu3UFmtW4CIikAkJV1bUsT+mnpeWrw17CjSDqgIRFLQ6L5d6F+Qw+/mbQo7irQDKgKRFGRmXDO6iLlrd+mgsagIRFLVNacXETF4RnsFKa9VRWBmA8wsM/b8AjP7vJl1DjaaiATppLwsLhjcjafmbqCqtj7sOBKi1u4RPAvUm9lAYCrQG5geWCoRaRO3ndufHRU1PDtfewWprLVF0ODudcDVwM/d/ctAj+BiiUhbGNc/nxG9O/PwrDXUN+iqpKmqtUVQa2YTgU8DM2LT0oOJJCJtxcyYcl5/1u3czytLt4UdR0LS2iK4GTgTuMfd15pZP+BXLb3BzKaZ2XYzW9LCmAvM7B0zW2pmf2t9bBGJl0tOPYnirtk89LfVuldBimpVEbj7Mnf/vLs/ZWZdgI7u/n9HedvjwGXNzYwdbL4fuNLdTwWubWVmEYmjaMS47bz+LNy0lzlrdoUdR0LQ2rOGXjezTmaWD8wHHjazH7f0HnefBbT0f9WngOfcfUNs/PZWZhaROLvm9CIKcjN4aJauP5SKWvvRUJ677wP+FXjC3ccCF5/gugcBXWIlM8/MbmxuoJlNNrNSMystKys7wdWKyJGy0qPcdFYxr68sY/nWfWHHkTbW2iJIM7MewCf458HiE5UGjAY+ClwKfMPMBjU10N2nunuJu5cUFhbGafUicrhJ44rJzojyk7+8q2MFKaa1RfC/wCvAand/28z6A++d4Lo3Aa+4e6W77wBmASNOcJkicpzystP53EUn8+dl7/PoG2vDjiNtqLUHi3/n7sPd/bOx12vc/ZoTXPcLwDlmlmZm2cBYYPkJLlNETsCU8/tz6and+e7LK3hr9Y6w40gbae3B4iIz+33s
dNDtZvasmRUd5T1PAbOBwWa2ycxuMbMpZjYFwN2XA38CFgFzgUfcvdlTTUUkeGbGD68dQd+u2Xzt2cXU1TeEHUnaQGs/GnoMeBHoGXv8ITatWe4+0d17uHu6uxe5+6Pu/qC7P3jYmB+4+1B3H+buPz3ev4SIxE/HrHS+etkpbNi1nxmLdL+CVNDaIih098fcvS72eBzQUVuRJPWRId0Z1D2XX8xcRYMuPZH0WlsEO83sBjOLxh43ADuDDCYi4YlEjDsuHMh72yv48zJdeiLZtbYIPkPjqaPbgK3ABOCmgDKJSDvwseE9Ke6azb2vaq8g2bX2rKH17n6luxe6ezd3vwo40bOGRKQdi0aMuy4+mWVb9/HyEu0VJLMTuUPZF+OWQkTapStH9GJQ91x+9JeVOoMoiZ1IEVjcUohIuxSNGF+6ZDBryip5bsHmsONIQE6kCPShoUgKuGRod0YU5fHz197TXkGSarEIzKzczPY18Sin8fsEIpLkzIzPXjCQjbsO8CfdvCYptVgE7t7R3Ts18ejo7mltFVJEwvWRod3pV5DDQ39bowvSJaET+WhIRFJENGLcem4/Fm/ey+w1+gpRslERiEirHLx5zdRZa8KOInGmIhCRVslKj/LpMxtvXrNyW3nYcSSOVAQi0mo3jOtLh/So9gqSjIpARFqtS04G153Rmxfe2czWvQfCjiNxoiIQkWNyyzn9aHDnsTfXhR1F4kRFICLHpHd+Nh8f0ZMnZq9j8x7tFSQDFYGIHLMvXzoYd/juS7q7bDJQEYjIMSvqks2/nT+AGYu2MnftrrDjyAkKrAjMbFrs/sYt3ofYzM4wszozmxBUFhGJv8+eP4CeeVn81+8XU1FdF3YcOQFB7hE8DlzW0gAziwL/B/w5wBwiEoAOGVG+P2EEa3ZUctdTC6jXzWsSVmBF4O6zgKPtM34OeBbYHlQOEQnOOScX8N8fH8qrK7bzwz+vDDuOHKfQjhGYWS/gauCBVoydbGalZlZaVlYWfDgRabUbzyzmmtOLeOTva9i+ryrsOHIcwjxY/FPgq+5+1Aucu/tUdy9x95LCwsI2iCYix+Lz/zKQugbnV3PWhx1FjkOYRVACPG1m64AJwP1mdlWIeUTkOPXtmsPFQ7rz5D82UFVbH3YcOUahFYG793P3YncvBp4Bbnf358PKIyIn5jNn92NXZQ3P65aWCSfI00efAmYDg81sk5ndYmZTzGxKUOsUkfCM65/P0B6deOSNtTqDKMEEdpcxd594DGNvCiqHiLQNM+P2Cwdw5/QF/HHxVq4cobvZJgp9s1hE4uaKYT0Y3L0jP/vru9orSCAqAhGJm0jE+MLFJ7O6rJI/LNwSdhxpJRWBiMTVpaeexJAenfjhn1dSXlUbdhxpBRWBiMRVJGL87/hT2bLnAF9/fgnu+oiovVMRiEjcnVGcz79fPIgX3tnC7+ZtCjuOHIWKQEQCcfuFAxnXP5+7/7CMffqIqF1TEYhIIKIR4+sfHUp5dR2/1qUn2jUVgYgEZlivPM4bVMi0N9bq0hPtmIpARAJ1+wUD2FFRw29LN4YdRZqhIhCRQI3tl8/pfTrz4OurOVCjvYL2SEUgIoEyM752+RC27K3ivpnvhR1HmqAiEJHAjemXzzWnFzF11hpWba8IO44cQUUgIm3iP644hQ7pUb75gr5k1t6oCESkTRTkZvLlSwfz1uqdvLJ0W9hx5DAqAhFpMxPH9GFw947c89JynU7ajqgIRKTNpEUjfONjQ9m46wDT3lwbdhyJURGISJs65+QCLh7SnV+8tort+6rCjiOoCEQkBP/10SHU1Dfwg1dWhh1FCPaexdPMbLuZLWlm/vVmtsjMFpvZW2Y2IqgsItK+9CvI4eaz+/HM/E0s3rQ37DgpL8g9gseBy1qYvxY4391PA+4GpgaYRUTamTsvGkh+dgbffHEJNXUNYcdJaYEVgbvPAna1MP8td98dezkHKAoqi4i0P52y0vnmx4eyYMMevvS7hTTo
HsehSQs7QMwtwMvNzTSzycBkgD59+rRVJhEJ2PiRvdi6t4rvvbyCrjkZfOvKU8OOlJJCLwIzu5DGIjinuTHuPpXYR0clJSX6tUEkiUw5fwBl5dU8+sZazhlYwMVDu4cdKeWEetaQmQ0HHgHGu/vOMLOISHi+etkpnHJSR/7r+cXsPaC7mbW10IrAzPoAzwGT3P3dsHKISPgy0iJ8f8JwysqrueePy8KOk3IC+2jIzJ4CLgAKzGwT8N9AOoC7Pwh8E+gK3G9mAHXuXhJUHhFp34YXdWbK+QO4//XV9O2awx0XDgw7UsoIrAjcfeJR5t8K3BrU+kUk8XzpksFs3VvFD15ZSXrUmHzegLAjpYTQDxaLiBwUjRg/mDCc2voGvvPSCtKjEW4+u1/YsZKeikBE2pW0aISfXDeSunrnf/6wjLRohEnj+oYdK6npWkMi0u6kRyPcO3EUFw/pxjeeX8Jv3t4QdqSkpiIQkXYpIy3CL64/nfMHFfK15xbz7LxNYUdKWioCEWm3MtOiPDRpNGcPKODLzyzkhXc2hx0pKakIRKRdy0qP8vCNJYzpl88Xf7uQlxZvDTtS0lERiEi71yEjyqOfPoNRvTtz19MLdOnqOFMRiEhCyMlM45FPl9A1J5O7nl7A/pq6sCMlDRWBiCSMztkZ/Pi6EazdWcndM3QpinhREYhIQjlrQAGTz+vPU3M38n9/WoG7Lkh8ovSFMhFJOF+59BTKq+p44PXV7D1Qy7fHDyMSsbBjJSwVgYgknGjEuOeqYXTKSufBv62mQ3qUr390CLELWMoxUhGISEIyM7562WCqaut59I21FHbMZMr5ukjd8VARiEjCMjO++bGh7KysOXS7y2tLeocdK+GoCEQkoUUixo+uHcHuyhq+9txiumRn6HaXx0hnDYlIwstIi/DgpNGc2rMTd0yfz4xFW8KOlFBUBCKSFHIz03j85jGc1iuPO6cv4Cd/eVenlraSikBEkkZ+TgZP3jaWCaOL+Nmr7/HoG2vDjpQQAisCM5tmZtvNbEkz883M7jWzVWa2yMxODyqLiKSOzLQoP5gwnCtOO4nvvLScWe+WhR2p3Qtyj+Bx4LIW5l8OnBx7TAYeCDCLiKQQM+MHE0YwqHtH7pw+n3nrd4cd6YS4Ozc/Npfflm4MZPmBFYG7zwJ2tTBkPPCEN5oDdDazHkHlEZHUkpOZxsM3lpCfk8HEh+ck9AHkuWt3MXNlGQ0NwRzzCPMYQS/g8HrbFJv2IWY22cxKzay0rEy7eSLSOr3zs3nu9rMZUdR4APkXM1cl5AHkX85eR16HdMaPbPKfyBOWEAeL3X2qu5e4e0lhYWHYcUQkgeTnZPDrW8cyfmRPfvDKSr7yzCKqauvDjtVqm/cc4JWl7/PJMb3pkBENZB1hfqFsM3D4VwCLYtNEROIqMy3KT68bSXHXHH726nss2bKPn08cxcBuuWFHO6pfz1mPuzNpXN/A1hHmHsGLwI2xs4fGAXvdXfegE5FAmBn//pFBTLuphPf3VfHxn7/BW6t3hB2rRdv3VTH9Hxu4ZOhJFHXJDmw9QZ4++hQwGxhsZpvM7BYzm2JmU2JDXgLWAKuAh4Hbg8oiInLQRad05+W7zqVPfja3/rKU0nUtndMSnvoG5wu/eYfqunq+dMmgQNdliXbgpKSkxEtLS8OOISIJrqy8musems328mp+fetYRvbuHHakD7j31ff48V/e5fsThvOJOFxIz8zmuXtJU/MS4mCxiEi8FXbMZPpt48jPyeDGR//B0i17w450yJw1O/npX9/l6lG9uHZ0UeDrUxGISMo6KS+L6beNJTczjUmPzuXV5e+HHYmdFdXc9fQCirvmcPdVw9rkZjsqAhFJaUVdspl+2zgKcjO45Zel3PHkfOat3xXYl7dacqCmni/85h1276/l558aRW5m25zYqfsRiEjKKy7IYcbnzuWhv63mvpmr+OPirfTq3IGvXX4KHxveo01+K1+2ZR+ff3oBq8sq+O7V
p3Fqz7zA13mQDhaLiBymvKqWvyx7n8feXMfizXv5l1O6ccHgQvoV5DKufz5p0fh/kDJv/W6uf2QOHbPS+cknRnLOyQVxX0dLB4tVBCIiTairb2Dam2v52V/fo7Km8ZvIA7vl8pVLB/ORod3jtpewuqyCCQ+8RV6HdH475Uy6dcyKy3KPpCIQETlODQ3O9vJq3l63i5/89V3WlFXy+YsG8sVLBp/wsiur67jsZ7PYX13Pc7efRd+uOXFI3DSdPioicpwiEeOkvCw+PqInf/7CeXyipIh7X1sVl5vePPD6ajbuOsADN4wOtASORgeLRURaKS0a4TtXn8a+A3XcPWMZ2RlRJo7pc1zL2rhrP1P/voarRvZkTL/8OCc9NtojEBE5BmnRCD+bOJILBhfyH88tZvo/NhzXcr738goiBl+57JQ4Jzx2KgIRkWOUmRbloUmjueiUbvzn7xfzX79fzO7KmmbH19Q18NqK99m85wANDc7dM5bxx8Vb+ez5A+nZuUMbJm+aPhoSETkOmWlRHrjhdL738gqemL2eGYu28v8uGcSnxvalorqOmSu2U1lTx579tTw5Zz1b9lYRMehXkMPqskpuOquYOy8aGPZfA9BZQyIiJ2zltnK+9eJSZq/ZSZ/8bLbtq6KmruHQ/NF9u3DrOf1YumUfMxZt4fqxfbn13H5t8kW1g3T6qIhIwNydlxZvY9qbazm1Zyf+9fQieuZlEYkYXXMy2vQf/aa0VAT6aEhEJA7MjI8O78FHh/cIO8ox08FiEZEUpyIQEUlxKgIRkRQXaBGY2WVmttLMVpnZ15qY38fMZprZAjNbZGZXBJlHREQ+LMib10eBXwCXA0OBiWY29IhhXwd+6+6jgE8C9weVR0REmhbkHsEYYJW7r3H3GuBpYPwRYxzoFHueB2wJMI+IiDQhyCLoBWw87PWm2LTDfQu4wcw2AS8Bn2tqQWY22cxKzay0rKwsiKwiIikr7IPFE4HH3b0IuAL4lZl9KJO7T3X3EncvKSwsbPOQIiLJLMgvlG0Geh/2uig27XC3AJcBuPtsM8sCCoDtzS103rx5O8xsfROz8oC9rcjVmnEtjWluXlPTm5pWAOw4yvqD1Nr/TkEsR9vo6OK1fY53WfHaRvHaPvbLdHkAAAV7SURBVE1NT+WfoWN535Hj+jY70t0DedBYMmuAfkAGsBA49YgxLwM3xZ4PofEYgR3n+qbGa1xLY5qb19T0ZqaVBvXfPJ7/nYJYjrZR222fsLdRvLZPU9NT+Wcontvo8EdgHw25ex1wJ/AKsJzGs4OWmtn/mtmVsWFfAm4zs4XAUzSWwvFe/OgPcRzX0pjm5jU1vbWZ2lK8Mh3PcrSNji6eecLcRvHaPq1ZV1sL82foWN7X6uUn3EXnEp2ZlXozF36S9kHbqH3T9om/sA8Wp6KpYQeQo9I2at+0feJMewQiIilOewQiIilORSAikuJUBCIiKU5F0I6Y2RAze9DMnjGzz4adRz7MzK4ys4fN7DdmdknYeeSDzKy/mT1qZs+EnSWRqAjixMymmdl2M1tyxPQWL8V9OHdf7u5TgE8AZweZNxXFaRs97+63AVOA64LMm2ritH3WuPstwSZNPjprKE7M7DygAnjC3YfFpkWBd4GP0HjRvbdpvL5SFPjuEYv4jLtvj33Z7rPAr9x9elvlTwXx2kax9/0IeNLd57dR/KQX5+3zjLtPaKvsiU43r48Td59lZsVHTD50KW4AM3saGO/u3wU+1sxyXgReNLM/AiqCOIrHNjIzA74HvKwSiK94/QzJsdNHQ8FqzaW4DzGzC8zsXjN7iMbLckvwjmkb0Xip9IuBCWY2JchgAhz7z1BXM3sQGGVm/xF0uGShPYJ2xN1fB14POYa0wN3vBe4NO4c0zd130nj8Ro6B9giC1ZpLcUu4tI3aN22fNqAiCNbbwMlm1s/MMmi8L/OLIWeSD9I2at+0fdqAiiBOzOwpYDYw2Mw2mdktzV2KO8ycqUzbqH3T
9gmPTh8VEUlx2iMQEUlxKgIRkRSnIhARSXEqAhGRFKciEBFJcSoCEZEUpyKQpGFmFW28vkfMbGgbr/MLZpbdluuU5KfvEUjSMLMKd8+N4/LSYl9oajOxq5uauzc0M38dUOLuO9oylyQ37RFIUjOzQjN71szejj3Ojk0fY2azzWyBmb1lZoNj028ysxfN7DXg1dgVYV+P3TVuhZk9GfvHmtj0ktjzCjO7x8wWmtkcM+semz4g9nqxmX27qb0WMyuO3XjlCWAJ0NvMHjCzUjNbamb/Exv3eaAnMNPMZsamXRL7e8w3s9+ZWdyKUFKIu+uhR1I8gIompk0Hzok97wMsjz3vBKTFnl8MPBt7fhONlzrOj72+ANhL48XOIjReAuHg8l6n8bdzAAc+Hnv+feDrseczgImx51OayVgMNADjDpt2cP3R2HqGx16vAwpizwuAWUBO7PVXgW+GvR30SLyHLkMtye5iYGjsl3iATrHfmvOAX5rZyTT+I55+2Hv+4u67Dns91903AZjZOzT+w/3GEeupofEffYB5NN5RC+BM4KrY8+nAD5vJud7d5xz2+hNmNpnGS8X3AIYCi454z7jY9Ddjf78MGotK5JioCCTZRWj8Tbvq8Ilmdh8w092vjt0V6/XDZlcesYzqw57X0/TPTa27+1HGtOTQOs2sH/D/gDPcfbeZPQ5kNfEeo7G0Jh7jukQ+QMcIJNn9mca7igFgZiNjT/P453Xtbwpw/XOAa2LPP9nK93SisRj2xo41XH7YvHKg42HLPtvMBgKYWY6ZDTrxyJJqVASSTLJjly8++Pgi8HmgxMwWmdky/nn3qu8D3zWzBQS7Z/wF4ItmtggYSOPxhha5+0JgAbCCxo+T3jxs9lTgT2Y2093LaCyxp2LLnw2cEt/4kgp0+qhIgGLn/B9wdzezT9J44Hh82LlEDqdjBCLBGg3cFzvldA/wmZDziHyI9ghERFKcjhGIiKQ4FYGISIpTEYiIpDgVgYhIilMRiIikOBWBiEiK+/+r97OgVfa+AQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Visualise what the finder recorded during the learning-rate range test above.\n",
    "lr_finder.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.02880642644595004"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Learning rate suggested by the finder; it is written into the optimizer further below.\n",
    "lr_finder.lr_suggestion()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's now set the suggested learning rate on the optimizer and train the model with it.\n",
    "\n",
    "*Note that the model and optimizer were restored to their initial states when `FastaiLRFinder` finished.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4583f23d189f47608d2a96c837bbfbd9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b886a03f924c49b585c3b560605eda9c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "40055607a67641efaf1070fc5b52aa9b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "34ae5ed614c74548863410d435a76c2c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "97d01e750fa34d6c8c8e383b30fd26ab",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "99d408f0aa714d51a3c2a1c3302a4f6e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "018670132a6e4b92957674fc414c8513",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a8d5cfe6ee134919912f5648dcc38605",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9b2727045da940d38122ed90c22ee871",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5fed7fc698ea4b8a94a5e3997d4f29bf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=235.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "{'acc': 0.9857, 'loss': 0.043791117322444915}\n"
     ]
    }
   ],
   "source": [
    "# Apply the suggested learning rate to every parameter group, not only the first one.\n",
    "suggested_lr = lr_finder.lr_suggestion()\n",
    "for param_group in optimizer.param_groups:\n",
    "    param_group[\"lr\"] = suggested_lr\n",
    "\n",
    "# Train and evaluate again, this time with the suggested learning rate.\n",
    "trainer.run(trainloader, max_epochs=10)\n",
    "evaluator.run(testloader)\n",
    "print(evaluator.state.metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}